import torch
import torch.nn as nn
import torch.nn.functional as F

# Size of the latent code. The encoder's two heads each emit this many values
# per sample (intended as n means and n log variances), and the decoder's first
# layer takes a vector of this size.
bottleneck_size: int = 2

# Dropout probability passed to every F.dropout call in the models below;
# 0 turns dropout into the identity.
drop: float = 0


class Encoder(nn.Module):
    """Encode a 28x28 image into two ``bottleneck_size``-wide outputs.

    The input is flattened to 784 features and passed through three hidden
    Linear layers (784 -> 512 -> 128 -> 64). Each hidden layer is followed by
    dropout (probability ``drop``, active only in training mode) and a ReLU.
    Two independent linear heads then read the 64-d hidden vector.
    """

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 512)
        self.lin2 = nn.Linear(512, 128)
        self.lin3 = nn.Linear(128, 64)
        self.mean_head = nn.Linear(64, bottleneck_size)
        # NOTE(review): despite the name, this head is a plain Linear layer, so
        # its output is unconstrained and can be negative. Callers presumably
        # treat it as a log-variance / log-std rather than a raw std. Confirm.
        self.std_head = nn.Linear(64, bottleneck_size)

    def forward(self, x):
        """Return ``(mean_head(h), std_head(h))`` for the hidden features ``h``.

        Args:
            x: tensor whose element count is a multiple of 784. It is
                flattened to ``(-1, 784)``, so (N, 784), (N, 28, 28) and
                (N, 1, 28, 28) inputs are all accepted.

        Returns:
            Tuple of two tensors, each of shape (N, bottleneck_size).
        """
        # reshape instead of view: view raises on non-contiguous inputs (e.g.
        # after permute/transpose); reshape still returns a view when it can.
        x = x.reshape(-1, 28 * 28)
        for layer in (self.lin1, self.lin2, self.lin3):
            x = F.dropout(layer(x), p=drop, training=self.training)
            x = F.relu(x)
        return self.mean_head(x), self.std_head(x)

class Decoder(nn.Module):
    """Map a latent code back to a flattened 28x28 image with values in (0, 1).

    Architecture: bottleneck_size -> 64 -> 128 -> 512 -> 784. Each hidden
    Linear layer is followed by dropout (probability ``drop``, active only in
    training mode) and a ReLU. The final Linear output goes through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Creation order and attribute names are significant: they fix the
        # order in which parameter initialisation draws from the RNG and the
        # state_dict keys (lin1..lin4, sig).
        self.lin1 = nn.Linear(bottleneck_size, 64)
        self.lin2 = nn.Linear(64, 128)
        self.lin3 = nn.Linear(128, 512)
        self.lin4 = nn.Linear(512, 784)
        self.sig = nn.Sigmoid()

    def forward(self, z):
        """Decode ``z``, whose last dimension must equal ``bottleneck_size``.

        Returns a tensor with the same leading dimensions as ``z`` and a last
        dimension of 784, squashed into (0, 1) by the sigmoid.
        """
        hidden = z
        for layer in (self.lin1, self.lin2, self.lin3):
            dropped = F.dropout(layer(hidden), p=drop, training=self.training)
            hidden = F.relu(dropped)
        logits = self.lin4(hidden)
        return self.sig(logits)

class Discriminator(nn.Module):
    """Score an image as real (close to 1) or fake (close to 0).

    Besides the prediction, the activations of an intermediate layer (the
    64-d output of the last hidden layer) can be returned as image features
    by passing ``return_features=True``.
    """

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x, *, return_features=False):
        """Classify ``x`` as real or fake.

        Args:
            x: real or generated images whose element count is a multiple of
                784. The tensor is flattened to ``(-1, 784)``, so both decoder
                outputs (N, 784) and raw images (N, 1, 28, 28) are accepted.
            return_features: if True, also return the intermediate features.

        Returns:
            ``y_hat`` of shape (N, 1) with values in (0, 1). If
            ``return_features`` is True, returns ``(y_hat, features)``, where
            ``features`` has shape (N, 64).
        """
        x = x.reshape(-1, 28 * 28)
        # Each hidden layer needs a nonlinearity (matching Encoder/Decoder).
        # Without it, lin1 -> lin2 -> lin3 collapse into a single affine map
        # and the discriminator is no more expressive than logistic regression.
        x = F.relu(F.dropout(self.lin1(x), p=drop, training=self.training))
        features = F.relu(F.dropout(self.lin2(x), p=drop, training=self.training))

        y_hat = self.sig(self.lin3(features))
        if return_features:
            return y_hat, features
        return y_hat